import torch
from torch import nn
import numpy as np

import torch.nn.functional as F

def get_flattened_metric(net, metric):
    """Concatenate a per-layer metric over every Conv2d / Linear module of ``net``.

    Args:
        net: module whose ``modules()`` are walked in traversal order (this
            includes ``net`` itself and all nested submodules).
        metric: callable applied to each matching layer; its result must
            provide ``.flatten()`` and be accepted by ``np.concatenate``.

    Returns:
        A 1-D numpy array of the flattened metric values, layer after layer,
        in module-traversal order.

    Raises:
        ValueError: from ``np.concatenate`` when ``net`` contains no Conv2d or
            Linear module (nothing to concatenate).
    """
    weighted_layer_types = (nn.Conv2d, nn.Linear)
    pieces = [
        metric(module).flatten()
        for module in net.modules()
        if isinstance(module, weighted_layer_types)
    ]
    return np.concatenate(pieces)


def get_grad_conflict(net, inputs, targets, loss_fn=F.cross_entropy):
    """Measure how consistently per-sample weight gradients agree in sign.

    For each sample ``i`` this back-propagates
    ``loss_fn(net(inputs[[i]]), targets[[i]])`` on its own, then collects the
    weight gradients of every Conv2d / Linear layer via
    ``get_flattened_metric``. Layers that received no gradient contribute
    zeros. The element-wise signs of the per-sample gradients are summed over
    samples, and the absolute value is taken. The NaN-ignoring mean of that
    vector is returned.

    Each element of the summed vector lies in ``[0, N]``. It equals ``N`` only
    when every sample's gradient has the same nonzero sign at that position.

    Args:
        net: model to probe. Its forward may return a tensor, which is used
            directly as the prediction. It may instead return a tuple, whose
            element 1 is used. That element is assumed to be the logits;
            TODO confirm for the model zoo in use.
        inputs: batch indexed along dim 0; each sample is fed as a batch of one.
        targets: targets aligned with ``inputs`` along dim 0.
        loss_fn: callable ``(outputs, targets) -> scalar tensor``.

    Returns:
        The score as a numpy floating-point scalar.

    Raises:
        ValueError: if ``inputs`` contains no samples.

    Side effects:
        Overwrites ``.grad`` of ``net``'s parameters. The gradients from the
        last sample are left in place.
    """
    num_samples = inputs.shape[0]
    if num_samples == 0:
        raise ValueError("get_grad_conflict needs at least one input sample")

    def _weight_grad(layer):
        # Layers that did not take part in the loss have ``grad is None``;
        # count them as zero so every sample yields a vector of equal length.
        if layer.weight.grad is None:
            return torch.zeros_like(layer.weight).cpu().numpy()
        return layer.weight.grad.detach().cpu().numpy()

    per_sample_grads = []
    for i in range(num_samples):
        net.zero_grad()
        # ``[[i]]`` (not ``[i]``) keeps the leading batch dimension of size 1.
        outputs = net(inputs[[i]])
        if isinstance(outputs, tuple):
            outputs = outputs[1]
        loss = loss_fn(outputs, targets[[i]])
        loss.backward()
        per_sample_grads.append(get_flattened_metric(net, _weight_grad))

    sign_sum = np.sign(np.stack(per_sample_grads)).sum(axis=0)
    return np.nanmean(np.abs(sign_sum))


def gradsign(train_loader, networks, train_mode=False, num_batch=-1, num_classes=100, verbose=False):
    """Score each network with the gradient-sign proxy from ``get_grad_conflict``.

    Only the first ``(inputs, targets)`` pair produced by ``train_loader`` is
    used. Each score is computed with cross-entropy loss and negated.

    Args:
        train_loader: iterable yielding ``(inputs, targets)`` pairs.
        networks: iterable of models. It is materialized into a list first, so
            a one-shot iterator (e.g. a generator) is not exhausted by the
            mode-setting pass before scoring. The networks are NOT moved to a
            device; they must already live on the device selected below.
        train_mode: if true, call ``train()`` on every network; otherwise
            call ``eval()``.
        num_batch: unused; kept for interface compatibility.
        num_classes: unused; kept for interface compatibility.
        verbose: unused; kept for interface compatibility.

    Returns:
        A list with one score per network, in input order.
    """
    networks = list(networks)

    # Use the current CUDA device when present, otherwise fall back to CPU so
    # the proxy can still run on machines without a GPU.
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = torch.device("cpu")

    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    inputs, targets = next(iter(train_loader))
    inputs = inputs.to(device)
    targets = targets.to(device)

    return [
        -get_grad_conflict(net=net, inputs=inputs, targets=targets, loss_fn=F.cross_entropy)
        for net in networks
    ]